import torch, numpy as np, scipy.sparse as sp
import torch.nn as nn
import torch.nn.functional as F

def _to_dense_torch(mat, device):
    import numpy as np, scipy.sparse as sp, torch
    if isinstance(mat, np.ndarray):
        arr = mat
    elif sp.issparse(mat):
        arr = mat.toarray()
    else:
        arr = np.asarray(mat)
    return torch.as_tensor(arr, dtype=torch.float32, device=device)

def unpool_one_level(H_coarse, clusters, N_fine):
    """Copy each coarse node's feature row onto all of its child (fine) nodes.

    Args:
        H_coarse: Tensor of shape ``(N_coarse, D)``; row ``i`` belongs to
            coarse node ``i``.
        clusters: Sequence whose ``i``-th entry is a 1-D collection of fine-node
            indices that are children of coarse node ``i`` (may be empty).
        N_fine: Number of rows in the fine-level output.

    Returns:
        Tensor of shape ``(N_fine, D)`` with the same dtype and device as
        ``H_coarse``. Fine nodes with no parent stay zero; a fine node listed
        under several parents (or repeated) receives the sum of their rows.
    """
    device = H_coarse.device
    # new_zeros keeps H_coarse's dtype; a hard-coded float32 buffer makes
    # index_add_ fail for float64/float16 inputs.
    H_fine = H_coarse.new_zeros((N_fine, H_coarse.size(1)))

    child_chunks, parent_chunks = [], []
    for parent, child_idx in enumerate(clusters):
        if len(child_idx) == 0:
            continue
        idx = torch.as_tensor(child_idx, dtype=torch.long, device=device)
        child_chunks.append(idx)
        parent_chunks.append(torch.full_like(idx, parent))

    if not child_chunks:
        return H_fine

    children = torch.cat(child_chunks)
    parents = torch.cat(parent_chunks)
    # A single scatter-add replaces one index_add_ call per cluster.
    H_fine.index_add_(0, children, H_coarse.index_select(0, parents))
    return H_fine

def unpool_to_level0(H_l, level_l, treeG):
    """Propagate features from tree level ``level_l`` down to level 0.

    The function walks from ``level_l`` toward the finest level. At each step
    it copies a coarse level's rows onto its children with
    ``unpool_one_level``, using ``treeG[level]['clusters']`` as the
    parent -> children map. The row count of the next finer level comes from
    ``treeG[level - 1]['adj'].shape[0]``.

    If ``level_l <= 0``, ``H_l`` is returned unchanged.
    """
    out = H_l
    for coarse in reversed(range(1, level_l + 1)):
        finer = coarse - 1
        out = unpool_one_level(
            out,
            treeG[coarse]['clusters'],
            treeG[finer]['adj'].shape[0],
        )
    return out

class HaarSpectralBlock(nn.Module):
    """Learnable diagonal spectral filter in a (truncated) Haar basis.

    The block holds one learnable gain per basis vector, for up to ``max_K``
    vectors. In ``forward``, only the first ``min(U.size(1), max_K)`` columns of
    ``U`` are used.
    """

    def __init__(self, max_K: int):
        super().__init__()
        # One trainable gain per retained basis column.
        self.lambda_vec = nn.Parameter(torch.randn(max_K))

    def forward(self, U: torch.Tensor, X: torch.Tensor):
        """Return ``relu(U_k @ diag(lambda_k) @ U_k^T @ X)`` with ``k`` capped at ``max_K``."""
        n_modes = min(U.size(1), self.lambda_vec.size(0))
        basis = U[:, :n_modes]
        coeffs = basis.transpose(0, 1) @ X              # project onto the basis
        gains = self.lambda_vec[:n_modes].unsqueeze(1)  # column vector, broadcasts over features
        return F.relu(basis @ (coeffs * gains))

class NodeHaarUnpoolClassifier(nn.Module):
    """Node classifier combining Haar-spectral features from several tree levels.

    For each level ``l`` (up to ``num_levels``), the model:

    1. Runs the level's node features through a shared MLP (``pre``) with dropout.
    2. Filters them with a shared :class:`HaarSpectralBlock`, using the level's
       basis ``U_list[l]``. If that entry is ``None``, a basis is built from
       ``treeG``.
    3. Copies the result back to level-0 nodes with ``unpool_to_level0``.

    The per-level level-0 features are concatenated, passed through dropout,
    and mapped to class logits.
    """

    def __init__(self, in_dim: int, hid_dim: int, num_classes: int, max_K: int, num_levels: int):
        super().__init__()
        self.num_levels = num_levels
        self.pre = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim)
        )
        self.block = HaarSpectralBlock(max_K=max_K)
        # Input width assumes one hid_dim-wide slab per level.
        self.classifier = nn.Linear(hid_dim * num_levels, num_classes)
        self.dropout = nn.Dropout(p=0.3)

    @staticmethod
    def _fallback_basis(treeG, l):
        """Build a dense ``(N_l, K_l)`` float32 matrix from ``treeG[l]['u']``.

        ``N_l`` is ``len(treeG[l]['u'])``. ``K_l`` is ``len(treeG[l + 1]['u'])``,
        or 1 at the last level.
        """
        u = treeG[l]['u']
        N_l = len(u)
        K_l = len(treeG[l + 1]['u']) if l < len(treeG) - 1 else 1
        U_np = np.zeros((N_l, K_l), dtype=np.float32)
        # NOTE(review): column k is filled with u[k] (broadcast if u[k] is a
        # scalar); confirm this matches the intended layout of treeG[l]['u'].
        for k in range(K_l):
            U_np[:, k] = u[k]
        return U_np

    def forward(self, U_list, features_list, treeG):
        """Return class logits for the level-0 nodes.

        Args:
            U_list: Per-level basis matrices (``None`` entries fall back to
                ``treeG``).
            features_list: Per-level node feature matrices.
            treeG: Per-level dicts read for ``'u'``, ``'clusters'`` and
                ``'adj'``.

        Levels beyond ``len(U_list)`` contribute zeros, so the classifier
        input width always matches ``hid_dim * num_levels``.
        """
        # Follow the model's own placement. Choosing CUDA whenever it is
        # available breaks CPU-resident models on GPU machines.
        device = next(self.parameters()).device
        L_eff = min(self.num_levels, len(U_list))
        H_per_level = []
        for l in range(L_eff):
            X_l = _to_dense_torch(features_list[l], device)
            X_l = self.dropout(self.pre(X_l))
            if U_list[l] is None:
                U_l = _to_dense_torch(self._fallback_basis(treeG, l), device)
            else:
                U_l = _to_dense_torch(U_list[l], device)
            H_l = self.block(U_l, X_l)
            H_per_level.append(unpool_to_level0(H_l, level_l=l, treeG=treeG))

        missing = self.num_levels - L_eff
        if missing > 0:
            # Without padding, fewer levels than num_levels would mismatch the
            # classifier's input width and raise.
            if H_per_level:
                n0, dtype = H_per_level[0].size(0), H_per_level[0].dtype
            else:
                n0, dtype = treeG[0]['adj'].shape[0], torch.float32
            width = missing * self.pre[-1].out_features
            H_per_level.append(torch.zeros(n0, width, device=device, dtype=dtype))

        H0_cat = torch.cat(H_per_level, dim=1)
        H0_cat = self.dropout(H0_cat)
        logits = self.classifier(H0_cat)
        return logits

